{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gUl_qfOR8JV6"
      },
      "source": [
        "## Setup\n",
        "\n",
        "You will need to make a copy of this notebook in your Google Drive before you can edit the homework files. You can do so with **File &rarr; Save a copy in Drive**."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "iizPcHAp8LnA"
      },
      "outputs": [],
      "source": [
        "#@title mount your Google Drive\n",
        "#@markdown Your work will be stored in a folder called `cs285_f2021` by default to prevent Colab instance timeouts from deleting your edits.\n",
        "\n",
        "import os\n",
        "from google.colab import drive\n",
        "drive.mount('/content/gdrive')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "cellView": "form",
        "id": "nAb10wnb8N0m"
      },
      "outputs": [],
      "source": [
        "#@title set up mount symlink\n",
        "\n",
        "## shell-escaped form of the Drive folder (the backslash protects the\n",
        "## space in `My Drive`); a raw string keeps the backslash literal\n",
        "## without relying on Python's deprecated handling of unknown escapes\n",
        "DRIVE_PATH = r'/content/gdrive/My\\ Drive/cs285_f2021'\n",
        "## same folder without the shell escaping, for use from Python\n",
        "DRIVE_PYTHON_PATH = DRIVE_PATH.replace('\\\\', '')\n",
        "if not os.path.exists(DRIVE_PYTHON_PATH):\n",
        "  ## os.mkdir (not makedirs) so a missing parent folder raises an error\n",
        "  os.mkdir(DRIVE_PYTHON_PATH)\n",
        "\n",
        "## the space in `My Drive` causes some issues,\n",
        "## make a symlink to avoid this\n",
        "SYM_PATH = '/content/cs285_f2021'\n",
        "## lexists is also true for a dangling link, which would make creation fail\n",
        "if not os.path.lexists(SYM_PATH):\n",
        "  os.symlink(DRIVE_PYTHON_PATH, SYM_PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "gtS9-WsD8QVr"
      },
      "outputs": [],
      "source": [
        "#@title apt install requirements\n",
        "\n",
        "#@markdown Run each section with Shift+Enter\n",
        "\n",
        "#@markdown Double-click on section headers to show code.\n",
        "\n",
        "## system packages; apt output is discarded to keep the cell output short\n",
        "!apt update \n",
        "!apt install -y --no-install-recommends \\\n",
        "        build-essential \\\n",
        "        curl \\\n",
        "        git \\\n",
        "        gnupg2 \\\n",
        "        make \\\n",
        "        cmake \\\n",
        "        ffmpeg \\\n",
        "        swig \\\n",
        "        libz-dev \\\n",
        "        unzip \\\n",
        "        zlib1g-dev \\\n",
        "        libglfw3 \\\n",
        "        libglfw3-dev \\\n",
        "        libxrandr2 \\\n",
        "        libxinerama-dev \\\n",
        "        libxi6 \\\n",
        "        libxcursor-dev \\\n",
        "        libgl1-mesa-dev \\\n",
        "        libgl1-mesa-glx \\\n",
        "        libglew-dev \\\n",
        "        libosmesa6-dev \\\n",
        "        lsb-release \\\n",
        "        ack-grep \\\n",
        "        patchelf \\\n",
        "        wget \\\n",
        "        xpra \\\n",
        "        xserver-xorg-dev \\\n",
        "        xvfb \\\n",
        "        python-opengl > /dev/null 2>&1\n",
        "\n",
        "## %pip (rather than !pip) installs into the environment of the running kernel\n",
        "%pip install opencv-python==3.4.0.12"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "VcKGekJN80NO"
      },
      "outputs": [],
      "source": [
        "#@title download mujoco\n",
        "\n",
        "## MuJoCo is kept in a `mujoco` folder under SYM_PATH\n",
        "MJC_PATH = os.path.join(SYM_PATH, 'mujoco')\n",
        "if not os.path.exists(MJC_PATH):\n",
        "  %mkdir $MJC_PATH\n",
        "%cd $MJC_PATH\n",
        "\n",
        "## fetch and unpack the archive only when `mujoco200` is not there yet\n",
        "if not os.path.exists('{}/mujoco200'.format(MJC_PATH)):\n",
        "  !wget -q https://www.roboti.us/download/mujoco200_linux.zip\n",
        "  !unzip -q mujoco200_linux.zip\n",
        "  %mv mujoco200_linux mujoco200\n",
        "  %rm mujoco200_linux.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "cellView": "form",
        "id": "NTiH9f9y82F_"
      },
      "outputs": [],
      "source": [
        "#@title update mujoco paths\n",
        "\n",
        "import os\n",
        "\n",
        "## add the MuJoCo binaries to LD_LIBRARY_PATH; works when the variable\n",
        "## is unset and does not append a second copy when the cell is re-run\n",
        "mjc_bin_path = '{}/mujoco200/bin'.format(MJC_PATH)\n",
        "ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')\n",
        "if mjc_bin_path not in ld_library_path.split(':'):\n",
        "  os.environ['LD_LIBRARY_PATH'] = (\n",
        "      ld_library_path + ':' + mjc_bin_path if ld_library_path else mjc_bin_path\n",
        "  )\n",
        "os.environ['MUJOCO_PY_MUJOCO_PATH'] = '{}/mujoco200'.format(MJC_PATH)\n",
        "os.environ['MUJOCO_PY_MJKEY_PATH'] = '{}/mjkey.txt'.format(MJC_PATH)\n",
        "\n",
        "## installation on colab does not find *.so files\n",
        "## in LD_LIBRARY_PATH, copy over manually instead\n",
        "!cp $MJC_PATH/mujoco200/bin/*.so /usr/lib/x86_64-linux-gnu/"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A0kPh99l87q0"
      },
      "source": [
        "Ensure your `mjkey.txt` is in `/content/cs285_f2021/mujoco` before running this step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "X-LoOdZg84pI"
      },
      "outputs": [],
      "source": [
        "#@title clone and install mujoco-py\n",
        "\n",
        "%cd $MJC_PATH\n",
        "if not os.path.exists('mujoco-py'):\n",
        "  !git clone https://github.com/openai/mujoco-py.git\n",
        "%cd mujoco-py\n",
        "%pip install -e .\n",
        "\n",
        "## cythonize at the first import\n",
        "import mujoco_py"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "-XcwBiBN8-Fg"
      },
      "outputs": [],
      "source": [
        "#@title clone homework repo\n",
        "#@markdown Note that this is the same codebase from homework 1,\n",
        "#@markdown so you may need to move your old `homework_fall2021`\n",
        "#@markdown folder in order to clone the repo again.\n",
        "\n",
        "#@markdown **Don't delete your old work though!**\n",
        "#@markdown You will need it for this assignment.\n",
        "\n",
        "%cd $SYM_PATH\n",
        "!git clone https://github.com/berkeleydeeprlcourse/homework_fall2021.git\n",
        "%cd homework_fall2021/hw5\n",
        "%pip install -r requirements_colab.txt -f https://download.pytorch.org/whl/torch_stable.html\n",
        "%pip install -e ."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "cellView": "both",
        "id": "g5xIOIpW8_jC"
      },
      "outputs": [],
      "source": [
        "#@title set up virtual display\n",
        "\n",
        "from pyvirtualdisplay import Display\n",
        "\n",
        "## named `virtual_display` so that IPython's built-in `display()` helper\n",
        "## is not shadowed for the rest of the notebook\n",
        "virtual_display = Display(visible=0, size=(1400, 900))\n",
        "virtual_display.start()\n",
        "\n",
        "# For later\n",
        "from cs285.infrastructure.colab_utils import (\n",
        "    wrap_env,\n",
        "    show_video\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "2rsWAWaK9BVp"
      },
      "outputs": [],
      "source": [
        "#@title test virtual display\n",
        "\n",
        "#@markdown If you see a video of a four-legged ant fumbling about, setup is complete!\n",
        "\n",
        "import gym\n",
        "import matplotlib\n",
        "matplotlib.use('Agg')\n",
        "\n",
        "env = wrap_env(gym.make(\"Ant-v2\"))\n",
        "\n",
        "## take up to 10 random actions, requesting an 'rgb_array' render\n",
        "## before each step; stop early once the env reports termination\n",
        "observation = env.reset()\n",
        "for i in range(10):\n",
        "    env.render(mode='rgb_array')\n",
        "    obs, rew, term, _ = env.step(env.action_space.sample())\n",
        "    if term:\n",
        "        break\n",
        "\n",
        "env.close()\n",
        "print('Loading video...')\n",
        "show_video()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QizpiHDh9Fwk"
      },
      "source": [
        "## Editing Code\n",
        "\n",
        "To edit code, click the folder icon on the left menu. Navigate to the corresponding file (`cs285_f2021/...`). Double click a file to open an editor. There is a timeout of about 12 hours with Colab while it is active (and less if you close your browser window). We sync your edits to Google Drive so that you won't lose your work in the event of an instance timeout, but you will need to re-mount your Google Drive and re-install packages with every new instance."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Nii6qk2C9Ipk"
      },
      "source": [
        "## Run Exploration or Exploitation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "4t7FUeEG9Dkf"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import time\n",
        "\n",
        "from cs285.infrastructure.rl_trainer import RL_Trainer\n",
        "from cs285.agents.explore_or_exploit_agent import ExplorationOrExploitationAgent\n",
        "from cs285.infrastructure.dqn_utils import get_env_kwargs, PiecewiseSchedule, ConstantSchedule\n",
        "\n",
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "cellView": "both",
        "id": "2fXlzARJ9i-t"
      },
      "outputs": [],
      "source": [
        "#@title runtime arguments\n",
        "\n",
        "class Args:\n",
        "  \"\"\"Experiment settings, accessible by attribute or by key.\n",
        "\n",
        "  The class attributes below are the defaults (editable through the Colab\n",
        "  form). Item access is forwarded to getattr/setattr, so an instance can\n",
        "  be indexed like a dict: `args['seed']` is the same as `args.seed`.\n",
        "  \"\"\"\n",
        "\n",
        "  def __getitem__(self, key):\n",
        "    \"\"\"Return the attribute named `key`; AttributeError if it is missing.\"\"\"\n",
        "    return getattr(self, key)\n",
        "\n",
        "  def __setitem__(self, key, val):\n",
        "    \"\"\"Store `val` as the attribute named `key` on this instance.\"\"\"\n",
        "    setattr(self, key, val)\n",
        "\n",
        "  def __contains__(self, key):\n",
        "    \"\"\"Return True when an attribute named `key` exists.\"\"\"\n",
        "    return hasattr(self, key)\n",
        "\n",
        "  env_name = \"PointmassEasy-v0\" #@param [\"PointmassEasy-v0\", \"PointmassMedium-v0\", \"PointmassHard-v0\", \"PointmassVeryHard-v0\"]\n",
        "  exp_name = 'temp'#@param {type: \"string\"}\n",
        "\n",
        "  #@markdown batches and steps\n",
        "  batch_size = 256 #@param {type: \"integer\"}\n",
        "  eval_batch_size = 1000 #@param {type: \"integer\"}\n",
        "\n",
        "  #@markdown exploration hyperparameters\n",
        "  use_rnd = False #@param {type: \"boolean\"}\n",
        "  unsupervised_exploration = False #@param {type: \"boolean\"}\n",
        "  num_exploration_steps = 10000 #@param {type: \"integer\"}\n",
        "\n",
        "  #@markdown offline training hyperparameters\n",
        "  offline_exploitation = False #@param {type: \"boolean\"}\n",
        "  cql_alpha = 0.0 #@param {type: \"raw\"}\n",
        "\n",
        "  #@markdown reward shifting hyperparameters\n",
        "  exploit_rew_shift = 0.0 #@param {type: \"raw\"}\n",
        "  exploit_rew_scale = 1.0 #@param {type: \"raw\"}\n",
        "\n",
        "  #@markdown exploration model hyperparameters\n",
        "  rnd_output_size = 5 #@param {type: \"integer\"}\n",
        "  rnd_n_layers = 2 #@param {type: \"integer\"}\n",
        "  rnd_size = 400 #@param {type: \"integer\"}\n",
        "\n",
        "  #@markdown experiment hyperparameters\n",
        "  seed = 1 #@param {type: \"integer\"}\n",
        "  no_gpu = False #@param {type: \"boolean\"}\n",
        "  which_gpu = 0 #@param {type: \"integer\"}\n",
        "  scalar_log_freq = 1000 #@param {type: \"integer\"}\n",
        "  save_params = False #@param {type: \"boolean\"}\n",
        "\n",
        "\n",
        "args = Args()\n",
        "\n",
        "## ensure compatibility with hw1 code\n",
        "args['train_batch_size'] = args['batch_size']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "8vi1S18pGi7o"
      },
      "outputs": [],
      "source": [
        "###### HARDCODED DETAILS, DO NOT CHANGE ######\n",
        "\n",
        "args['video_log_freq'] = -1 # Not used\n",
        "args['exploit_weight_schedule'] = ConstantSchedule(1.0)\n",
        "args['num_agent_train_steps_per_iter'] = 1\n",
        "args['num_critic_updates_per_agent_update'] = 1\n",
        "args['learning_starts'] = 2000\n",
        "args['num_timesteps'] = 50000\n",
        "args['double_q'] = True\n",
        "args['eps'] = 0.2\n",
        "\n",
        "## episode length per environment; an unlisted env leaves `ep_len` unset\n",
        "EP_LEN_BY_ENV = {\n",
        "  'PointmassEasy-v0': 50,\n",
        "  'PointmassMedium-v0': 150,\n",
        "  'PointmassHard-v0': 100,\n",
        "  'PointmassVeryHard-v0': 200,\n",
        "}\n",
        "if args['env_name'] in EP_LEN_BY_ENV:\n",
        "  args['ep_len'] = EP_LEN_BY_ENV[args['env_name']]\n",
        "\n",
        "## with RND the explore weight goes from 1 at step 0 to 0 at\n",
        "## `num_exploration_steps`; without RND it is constantly 0\n",
        "args['explore_weight_schedule'] = (\n",
        "  PiecewiseSchedule([(0,1), (args['num_exploration_steps'], 0)], outside_value=0.0)\n",
        "  if args['use_rnd']\n",
        "  else ConstantSchedule(0.0)\n",
        ")\n",
        "\n",
        "## unsupervised exploration overrides both weight schedules\n",
        "if args['unsupervised_exploration']:\n",
        "  args['explore_weight_schedule'] = ConstantSchedule(1.0)\n",
        "  args['exploit_weight_schedule'] = ConstantSchedule(0.0)\n",
        "\n",
        "  if not args['use_rnd']:\n",
        "    args['learning_starts'] = args['num_exploration_steps']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T0cJlp6s-ogO"
      },
      "outputs": [],
      "source": [
        "#@title create directories for logging\n",
        "\n",
        "## every run is written below the hw5 data folder\n",
        "data_path = '/content/cs285_f2021/homework_fall2021/hw5/data'\n",
        "if not os.path.exists(data_path):\n",
        "    os.makedirs(data_path)\n",
        "\n",
        "## one folder per run: hw5_<exp_name>_<env_name>_<timestamp>\n",
        "run_name = '_'.join(['hw5', args.exp_name, args.env_name, time.strftime(\"%d-%m-%Y_%H-%M-%S\")])\n",
        "logdir = os.path.join(data_path, run_name)\n",
        "args['logdir'] = logdir\n",
        "if not os.path.exists(logdir):\n",
        "    os.makedirs(logdir)\n",
        "\n",
        "print(\"LOGGING TO: \", logdir)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "I525KFRN-42s"
      },
      "outputs": [],
      "source": [
        "#@title Define Exploration Agent\n",
        "\n",
        "class Q_Trainer(object):\n",
        "    \"\"\"Sets up an RL_Trainer that uses ExplorationOrExploitationAgent.\"\"\"\n",
        "\n",
        "    def __init__(self, params):\n",
        "        \"\"\"Complete `params` in place and build the underlying RL_Trainer.\n",
        "\n",
        "        `params` must support item get/set. The same object is stored as\n",
        "        `self.params` and `self.agent_params`, so both names see every update.\n",
        "        \"\"\"\n",
        "        self.params = params\n",
        "\n",
        "        ## read the training settings before the env kwargs are merged in\n",
        "        train_args = dict(\n",
        "            num_agent_train_steps_per_iter=params['num_agent_train_steps_per_iter'],\n",
        "            num_critic_updates_per_agent_update=params['num_critic_updates_per_agent_update'],\n",
        "            train_batch_size=params['batch_size'],\n",
        "            double_q=params['double_q'],\n",
        "        )\n",
        "\n",
        "        env_args = get_env_kwargs(params['env_name'])\n",
        "\n",
        "        ## env kwargs are written first, so training settings win on a key clash\n",
        "        for overrides in (env_args, train_args):\n",
        "            for name, value in overrides.items():\n",
        "                params[name] = value\n",
        "\n",
        "        self.agent_params = params\n",
        "\n",
        "        params['agent_class'] = ExplorationOrExploitationAgent\n",
        "        params['agent_params'] = self.agent_params\n",
        "        ## re-read batch_size: the env kwargs may have replaced it above\n",
        "        params['train_batch_size'] = params['batch_size']\n",
        "        params['env_wrappers'] = self.agent_params['env_wrappers']\n",
        "\n",
        "        self.rl_trainer = RL_Trainer(self.params)\n",
        "\n",
        "    def run_training_loop(self):\n",
        "        \"\"\"Run training for `num_timesteps`, using the agent's actor to collect and to evaluate.\"\"\"\n",
        "        trainer = self.rl_trainer\n",
        "        trainer.run_training_loop(\n",
        "            self.agent_params['num_timesteps'],\n",
        "            collect_policy=trainer.agent.actor,\n",
        "            eval_policy=trainer.agent.actor,\n",
        "        )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "wF4LSRGn-_Cv"
      },
      "outputs": [],
      "source": [
        "#@title run training\n",
        "\n",
        "trainer = Q_Trainer(args)\n",
        "trainer.run_training_loop()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_kTH-tXkI-B-"
      },
      "outputs": [],
      "source": [
        "#@markdown You can visualize your runs with tensorboard from within the notebook\n",
        "\n",
        "## requires tensorflow==2.3.0\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir /content/cs285_f2021/homework_fall2021/hw5/data/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BwF7tQPQ66hB"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "run_hw5_expl.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
